#include "eigs.h"
#include "../sort.h"
#include "../IGL_ASSERT.h"
#include "../placeholders.h"
#include <Spectra/SymGEigsShiftSolver.h>

template <
  typename EigsScalar,
  typename DerivedU,
  typename DerivedS,
  typename Solver>
IGL_INLINE bool igl::spectra::eigs(
  const Eigen::SparseMatrix<EigsScalar> & A,
  const Eigen::SparseMatrix<EigsScalar> & B,
  const int k,
  const igl::EigsType type,
  Eigen::PlainObjectBase<DerivedU> & U,
  Eigen::PlainObjectBase<DerivedS> & S)
{
  // Mode-based front end: only igl::EIGS_TYPE_SM is supported, and it is
  // implemented by delegating to the fixed-shift overload with a tiny shift.
  //
  // Returns false if `type` is unsupported or if the delegated solve fails.
  IGL_ASSERT(k > 0 && "k should be positive");
  IGL_ASSERT(k < A.rows() && "k should be less than size of A");
  IGL_ASSERT(type == igl::EIGS_TYPE_SM && "Only SM supported");
  // IGL_ASSERT may be compiled out; never silently return SM results when a
  // different mode was requested.
  if(type != igl::EIGS_TYPE_SM)
  {
    return false;
  }
  // This seems like a hack. For the "eigs: grid" test this is necessary to get
  // at least 1e-4 error for the first 5 eigen values. It's annoying that this
  // means that the zero modes may be perturbed by O(sigma).
  //
  // I wonder if this is an issue with SparseLU and if UMFPACK would be better.
  //
  // Ideally this value would be 0.
  const EigsScalar sigma = 1e-8;
  // Solver cannot be deduced from the arguments, so it must be forwarded
  // explicitly; otherwise the callee's default solver would be used and the
  // caller's choice of Solver silently ignored.
  return igl::spectra::eigs<EigsScalar, DerivedU, DerivedS, Solver>(
    A, B, k, sigma, U, S);
}
template <
  typename EigsScalar,
  typename DerivedU,
  typename DerivedS,
  typename Solver>
IGL_INLINE bool igl::spectra::eigs(
  const Eigen::SparseMatrix<EigsScalar> & A,
  const Eigen::SparseMatrix<EigsScalar> & B,
  const int k,
  const EigsScalar sigma,
  Eigen::PlainObjectBase<DerivedU> & U,
  Eigen::PlainObjectBase<DerivedS> & S)
{
  // Computes k eigenpairs of the generalized symmetric problem A u = lambda B u
  // using Spectra's shift-and-invert mode around `sigma`. The eigenpairs
  // returned are those whose eigenvalues are nearest to sigma.
  //
  // On success:
  //   U  receives the eigenvectors as columns.
  //   S  receives the eigenvalues, sorted in descending order; U's columns are
  //      permuted to match.
  //
  // Returns false if A - sigma*B could not be factorized by Solver or if
  // Spectra did not converge.

  IGL_ASSERT(k > 0 && "k should be positive");
  IGL_ASSERT(k < A.rows() && "k should be less than size of A");

  // Spectra "B operator": y = B * x.
  class SparseMatProd
  {
    public:
      using Scalar = EigsScalar;
      const Eigen::SparseMatrix<Scalar> & m_B;
      explicit SparseMatProd(const Eigen::SparseMatrix<Scalar> & B) : m_B(B) {}
      int rows() const { return m_B.rows(); }
      int cols() const { return m_B.cols(); }
      void perform_op(const Scalar *x_in, Scalar *y_out) const
      {
        typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> VectorXS;
        Eigen::Map<const VectorXS> x(x_in, m_B.cols());
        Eigen::Map<      VectorXS> y(y_out, m_B.rows());
        y = m_B * x;
      }
  };

  // Spectra shift-and-invert operator: y = (A - sigma*B)^{-1} x.
  // Solver must expose .compute(A), .info() and .solve(x).
  class ShiftInvert
  {
  public:
    using Scalar = EigsScalar;
  private:
    const Eigen::SparseMatrix<Scalar> & m_A;
    const Eigen::SparseMatrix<Scalar> & m_B;
    Scalar m_sigma;
    Solver m_solver;
  public:
    bool m_solver_is_successfully_factorized;
    ShiftInvert(
        const Eigen::SparseMatrix<Scalar>& A,
        const Eigen::SparseMatrix<Scalar>& B,
        const Scalar sigma):
        m_A(A),
        m_B(B),
        m_sigma(sigma),
        m_solver_is_successfully_factorized(false)
    {
      IGL_ASSERT(m_A.rows() == m_A.cols() && "A must be square");
      IGL_ASSERT(m_B.rows() == m_B.cols() && "B must be square");
      IGL_ASSERT(m_A.rows() == m_B.cols() && "A and B must have the same size");
      set_shift(sigma, true);
    }
    // Spectra calls set_shift(sigma) again from its constructor; caching the
    // shift avoids refactorizing for the same value. Check `force` first so
    // m_sigma is never compared before it has been meaningfully set.
    void set_shift(const Scalar & sigma, const bool force = false)
    {
      if(!force && sigma == m_sigma)
      {
        return;
      }
      m_sigma = sigma;
      // Spectra's shift-and-invert contract: this operator must apply
      // (A - sigma*B)^{-1}; Spectra maps the Ritz values back with
      // lambda = 1/nu + sigma. Using A + sigma*B would bias every reported
      // eigenvalue by 2*sigma.
      const Eigen::SparseMatrix<Scalar> C = m_A - m_sigma * m_B;
      m_solver.compute(C);
      m_solver_is_successfully_factorized = (m_solver.info() == Eigen::Success);
    }
    int rows() const { return m_A.rows(); }
    int cols() const { return m_A.cols(); }
    void perform_op(const Scalar* x_in,Scalar* y_out) const
    {
      typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> VectorXS;
      Eigen::Map<const VectorXS>x(x_in, m_A.cols());
      Eigen::Map<VectorXS>y(y_out, m_A.rows());
      y = m_solver.solve(x);
    }
  };

  SparseMatProd Bop(B);
  ShiftInvert op(A, B, sigma);
  if(!op.m_solver_is_successfully_factorized)
  {
    return false;
  }

  // Spectra requires nev < ncv <= n. 2*k is the usual choice, but it exceeds
  // n whenever k > n/2, so clamp it (k < n keeps ncv > k).
  const Eigen::Index n = A.rows();
  const Eigen::Index twice_k = 2 * static_cast<Eigen::Index>(k);
  const Eigen::Index ncv = twice_k < n ? twice_k : n;
  Spectra::SymGEigsShiftSolver<ShiftInvert, SparseMatProd, Spectra::GEigsMode::ShiftInvert> geigs(op, Bop, k, ncv, sigma);

  geigs.init();
  // Largest magnitude of the transformed operator == eigenvalues nearest sigma.
  geigs.compute(Spectra::SortRule::LargestMagn);
  if (geigs.info() != Spectra::CompInfo::Successful)
  {
    return false;
  }
  U = geigs.eigenvectors().template cast<typename DerivedU::Scalar>();
  S = geigs.eigenvalues().template cast<typename DerivedS::Scalar>();

  // Sort eigenvalues descending and permute eigenvectors to match. S is both
  // input and output of the sort, so sort from a copy of its own type (not a
  // hard-coded VectorXd, which would not compile for non-double S).
  Eigen::VectorXi I;
  const DerivedS S_unsorted = S.derived();
  igl::sort(S_unsorted, 1, false, S, I);
  U = U(igl::placeholders::all,I).eval();
  return true;
}

#ifdef IGL_STATIC_LIBRARY
// Explicit template instantiation
// generated by autoexplicit.sh
template bool igl::spectra::eigs<double, Eigen::Matrix<double, -1, -1, 0, -1, -1>, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::SparseLU<Eigen::SparseMatrix<double, 0, int>, Eigen::COLAMDOrdering<int> > >(Eigen::SparseMatrix<double, 0, int> const&, Eigen::SparseMatrix<double, 0, int> const&, int , igl::EigsType, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, -1, 0, -1, -1> >&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> >&);
// The fixed-shift overload is public too. Instantiate it explicitly so
// static-library callers can link against it directly. Otherwise it exists only
// as an implicit instantiation used by the overload above, which is not
// guaranteed to be emitted as an externally linkable symbol.
template bool igl::spectra::eigs<double, Eigen::Matrix<double, -1, -1, 0, -1, -1>, Eigen::Matrix<double, -1, 1, 0, -1, 1>, Eigen::SparseLU<Eigen::SparseMatrix<double, 0, int>, Eigen::COLAMDOrdering<int> > >(Eigen::SparseMatrix<double, 0, int> const&, Eigen::SparseMatrix<double, 0, int> const&, int , double, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, -1, 0, -1, -1> >&, Eigen::PlainObjectBase<Eigen::Matrix<double, -1, 1, 0, -1, 1> >&);
#endif
